// Copyright 2024 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

package smith;

// Protos describing the model settings and input datasets of the dual encoder
// SMITH/BERT models.

// Specifies the train/eval datasets and the train/eval settings.
// Next Available Field: 13
message TrainEvalConfig {
  // Field numbers 3 and 11 are not assigned to any field of this message even
  // though the next available number is 13 (presumably they belonged to fields
  // that have since been removed). They are reserved so that they can never be
  // reused with a different meaning.
  reserved 3, 11;

  // File patterns for the train set, separated by commas if there are multiple
  // files. This field is required.
  optional string input_file_for_train = 1;

  // File patterns for the eval set, separated by commas if there are multiple
  // files.
  optional string input_file_for_eval = 2;

  // Total batch size for training.
  optional int32 train_batch_size = 4 [default = 32];

  // Total batch size for evaluation.
  optional int32 eval_batch_size = 5 [default = 32];

  // Total batch size for prediction.
  optional int32 predict_batch_size = 6 [default = 32];

  // Maximum number of eval steps.
  // This should be set according to the size of the eval data. During model
  // pre-training, a part of the training data can also be used for evaluation.
  optional int32 max_eval_steps = 7 [default = 100];

  // How often to save the model checkpoint.
  optional int32 save_checkpoints_steps = 8 [default = 1000];

  // How many steps to make in each estimator call.
  optional int32 iterations_per_loop = 9 [default = 1000];

  // Set this to true to always evaluate the model with the eval or test data,
  // even in the pre-train mode, so that we know whether the model overfits the
  // training data.
  optional bool eval_with_eval_data = 10 [default = true];

  // The weight used to compensate when there are more negative examples than
  // positive ones.
  optional float neg_to_pos_example_ratio = 12 [default = 1.0];
}

// Configuration for a BERT-based or SMITH-based encoder.
// Next Available Field: 18
message EncoderConfig {
  // The name of the model.
  optional string model_name = 12 [default = "smith_dual_encoder"];

  // Which pretrained checkpoint to use. This field is required for fine-tuning.
  optional string init_checkpoint = 1;

  // Which prediction checkpoint to use for the model prediction process.
  optional string predict_checkpoint = 2;

  // Location of the BERT config file.
  optional string bert_config_file = 3;

  // Location of the document level BERT config file, which is only used in
  // the SMITH model.
  optional string doc_bert_config_file = 4;

  // Location of the vocab file.
  optional string vocab_file = 5;

  // This is only used for the BERT model. The maximum total input sequence
  // length after tokenization. Sequences longer than this will be truncated,
  // and sequences shorter than this will be padded. Normally, this should be
  // no larger than the one used in pretraining. This should match the data
  // generation settings.
  optional int32 max_seq_length = 6 [default = 32];

  // Maximum number of masked LM predictions per sequence. Note that for the
  // SMITH model, the maximum number of masked LM predictions per document is
  // max_doc_length_by_sentence * max_predictions_per_seq.
  optional int32 max_predictions_per_seq = 7 [default = 5];

  // This is only used for the SMITH model. The maximum number of tokens in a
  // sentence.
  optional int32 max_sent_length_by_word = 8 [default = 32];

  // This is only used for the SMITH model. The maximum number of sentences in
  // a document.
  optional int32 max_doc_length_by_sentence = 9 [default = 64];

  // This is only used for the SMITH model. The number of looped sentences in a
  // document, used to control the TPU memory usage. This number should be
  // smaller than the setting of max_doc_length_by_sentence.
  // NOTE(review): both defaults are 64, so "no larger than" is presumably the
  // intended constraint -- confirm against the model code.
  optional int32 loop_sent_number_per_doc = 10 [default = 64];

  // This is only used for the SMITH model. Whether to update the parameters in
  // the sentence level Transformers of the SMITH model.
  optional bool sent_bert_trainable = 11 [default = true];

  // This is only used for the SMITH model. The maximum number of sentences to
  // be masked in each document.
  optional int32 max_masked_sent_per_doc = 14 [default = 2];

  // This is only used for the SMITH model. If true, add the masked sentence LM
  // loss into the total training loss.
  optional bool use_masked_sentence_lm_loss = 15 [default = false];

  // The number of different labels in the classification task.
  optional int32 num_labels = 13 [default = 2];

  // The mode used for combining the document representation. It can be normal,
  // sum_concat, mean_concat or attention.
  optional string doc_rep_combine_mode = 16 [default = "normal"];

  // The size of the attention vector in the attention layer that combines the
  // sentence level representations to generate the document level
  // representations.
  optional int32 doc_rep_combine_attention_size = 17 [default = 256];
}

// Configuration (hyperparameters) for a loss function.
// Next Available Field: 2
message LossConfig {
  // The amplifier applied to increase the logits value, so that
  // sigmoid(logits) is closer to 0 or 1. The default value is 6.0.
  optional float similarity_score_amplifier = 1 [default = 6.0];
}

// Config for the BERT/SMITH based dual encoder.
// Next Available Field: 4
message DualEncoderConfig {
  // Config for the BERT-based or SMITH-based encoder.

  optional EncoderConfig encoder_config = 1;

  // This field must be set to supply the train/eval data.
  optional TrainEvalConfig train_eval_config = 2;

  // Config for the loss function used in optimization. This field is required.
  optional LossConfig loss_config = 3;
}
